import torch
from transformers.generation import StoppingCriteria


# Custom stopping criterion: ends text generation once a stop sequence appears.
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the output ends with one of the stop sequences.

    Only the first row of ``input_ids`` (``input_ids[0]``) is inspected, so
    the decision is driven solely by the first sequence in the batch.
    """

    def __init__(self, stops, encounters=1):
        """Store the stop sequences.

        Args:
            stops: Iterable of stop sequences. Each item must support
                ``len()`` and ``.tolist()`` (e.g. a tensor of token ids).
            encounters: Stored on the instance but not consulted by
                ``__call__``; a single match is enough to stop.
        """
        super().__init__()
        self.stops = stops
        self.encounters = encounters

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        """Return True when the first sequence ends with any stop sequence.

        Args:
            input_ids: Token ids generated so far; only ``input_ids[0]`` is
                examined.
            scores: Unused; accepted to satisfy the stopping-criteria call
                signature.
            **kwargs: Ignored. Accepted so that callers forwarding extra
                keyword arguments (as the base-class signature allows) do
                not trigger a ``TypeError``.

        Returns:
            True if the trailing tokens of ``input_ids[0]`` equal one of
            ``self.stops``, otherwise False.
        """
        sequence = input_ids[0]
        # When the sequence is shorter than a stop, the slice yields fewer
        # tokens than the stop has, so the list comparison is simply False.
        return any(
            stop.tolist() == sequence[-len(stop):].tolist()
            for stop in self.stops
        )
